from typing import Iterable, Optional

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Subset


class TransferDataModule(LightningDataModule):
    """Wrap another ``LightningDataModule`` and serve index subsets of its splits.

    - The training split is restricted to ``learner_train_indices``.
      Optionally it is extended with ``teacher_train_indices``.
    - The validation split is restricted to ``learner_val_indices``.
    - The test split is passed through unchanged.

    Unless ``no_noise`` is set, the training loader draws from a
    ``NoiseDataset``, which yields ``(z, x, y)`` triples where ``z`` is a
    Gaussian noise vector of length ``z_vector_size``.
    """

    name = "transfer"

    def __init__(
        self,
        num_training_steps: int,
        z_vector_size: int,
        learner_train_indices: Iterable[int],
        learner_val_indices: Iterable[int],
        datamodule: LightningDataModule,
        no_noise: bool,
        teacher_train_indices: Optional[Iterable[int]] = None,  # used for debugging
        batch_size: Optional[int] = None,
    ):
        """
        Args:
            num_training_steps: With noise enabled, the noise dataset reports
                ``num_training_steps * batch_size`` items, i.e. this many full
                batches per epoch.
            z_vector_size: Length of each noise vector.
            learner_train_indices: Indices into ``datamodule.train_dataset``
                used for training.
            learner_val_indices: Indices into ``datamodule.val_dataset`` used
                for validation.
            datamodule: Wrapped datamodule. It must provide ``prepare_data``,
                ``transfer_setup``, ``train_dataset``, ``val_dataset``,
                ``test_dataset``, ``batch_size`` and ``num_workers``.
            no_noise: If true, train directly on the index subset without
                noise vectors.
            teacher_train_indices: Optional extra training indices, appended
                to the learner's (debugging aid).
            batch_size: Batch size for all loaders. A falsy value falls back
                to ``datamodule.batch_size``.
        """
        super().__init__()
        self.datamodule = datamodule
        self.z_vector_size = z_vector_size
        self.num_training_steps = num_training_steps
        # Materialise the indices so numpy arrays / generators behave like lists
        # (truth-testing and repeated iteration across multiple setup() calls).
        self.learner_val_indices = list(learner_val_indices)
        self.learner_train_indices = list(learner_train_indices)
        self.no_noise = no_noise
        self.teacher_train_indices = (
            list(teacher_train_indices) if teacher_train_indices is not None else None
        )
        self.batch_size = batch_size if batch_size else self.datamodule.batch_size

    @property
    def num_classes(self):
        """Number of classes.

        Taken from the wrapped datamodule when it exposes ``num_classes``;
        otherwise 10.
        """
        return getattr(self.datamodule, "num_classes", 10)

    def prepare_data(self):
        """Delegate to the wrapped datamodule's ``prepare_data``."""
        self.datamodule.prepare_data()  # to download the data

    def setup(self, stage: Optional[str] = None):
        """Run the wrapped module's ``transfer_setup`` and build the split subsets."""
        self.datamodule.transfer_setup()
        train_indices = list(self.learner_train_indices)
        if self.teacher_train_indices:
            # Debugging aid: also train on the teacher's samples.
            train_indices.extend(self.teacher_train_indices)
        self.train_dataset = Subset(self.datamodule.train_dataset, indices=train_indices)
        self.val_dataset = Subset(
            self.datamodule.val_dataset, indices=self.learner_val_indices
        )
        self.test_dataset = self.datamodule.test_dataset

    def train_dataloader(self):
        """Return the shuffled training loader.

        The loader is built over either the raw subset or a ``NoiseDataset``
        wrapping it, depending on ``no_noise``.
        """
        if self.no_noise:
            dataset = self.train_dataset
        else:
            dataset = NoiseDataset(
                noise_vector_size=self.z_vector_size,
                dataset_size=self.num_training_steps * self.batch_size,
                train_dataset=self.train_dataset,
            )
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.datamodule.num_workers,
            # With noise, the dataset length is a multiple of batch_size, so
            # nothing is dropped. With no_noise, a trailing partial batch of the
            # subset is discarded.
            drop_last=True,
            pin_memory=False,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return self._eval_dataloader(self.val_dataset)

    def test_dataloader(self):
        """Return an unshuffled loader over the wrapped test dataset."""
        return self._eval_dataloader(self.test_dataset)

    def _eval_dataloader(self, dataset):
        """Build the shared, unshuffled evaluation loader configuration."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.datamodule.num_workers,
            drop_last=False,
            pin_memory=True,
        )


class NoiseDataset(Dataset):
    """Pair samples from a ``Subset`` with freshly drawn Gaussian noise vectors.

    ``__len__`` reports ``dataset_size``, but ``__getitem__`` ignores the
    requested index. Instead it walks through the subset in a random order
    (``np.random.shuffle``), reshuffling at the start of every full pass.
    Each item is ``(z, x, y)``, where ``z = torch.randn(noise_vector_size)``
    and ``(x, y)`` is the underlying sample.

    NOTE(review): the walk position is instance state. Each DataLoader worker
    process advances its own copy independently.
    """

    def __init__(self, noise_vector_size: int, dataset_size: int, train_dataset: Subset):
        """
        Args:
            noise_vector_size: Length of each noise vector.
            dataset_size: Value reported by ``__len__``.
            train_dataset: Subset to draw ``(x, y)`` samples from. Must be
                non-empty.

        Raises:
            ValueError: If ``train_dataset`` has no indices.
        """
        self.noise_vector_size = noise_vector_size
        self.dataset_size = dataset_size
        self.train_dataset = train_dataset
        self.current_idx = 0  # position within the current shuffled pass
        self.train_length = len(self.train_dataset.indices)
        if self.train_length == 0:
            # Otherwise __getitem__ fails later with IndexError / ZeroDivisionError.
            raise ValueError("NoiseDataset requires a non-empty train_dataset")
        self.sampling_order = np.arange(self.train_length)

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, index):
        # `index` is intentionally unused; sampling follows the internal order.
        if self.current_idx == 0:
            # Start of a new pass over the subset: draw a fresh permutation.
            np.random.shuffle(self.sampling_order)

        selected_idx = self.train_dataset.indices[self.sampling_order[self.current_idx]]
        x, y = self.train_dataset.dataset[selected_idx]

        self.current_idx = (self.current_idx + 1) % self.train_length

        return torch.randn(self.noise_vector_size), x, y
